"""
Collection deletion service.

Handles:
- Full collection deletion with proper cleanup
- Documents are preserved but unlinked
- RAG index and chunks are deleted
"""

from typing import Dict, Any

from loguru import logger

from ....database.models.library import (
    Collection,
    DocumentCollection,
    DocumentChunk,
    CollectionFolder,
    RAGIndex,
    RagDocumentStatus,
)
from ....database.session_context import get_user_db_session
from ..utils.cascade_helper import CascadeHelper


class CollectionDeletionService:
    """Service for collection deletion operations."""

    def __init__(self, username: str):
        """
        Initialize collection deletion service.

        Args:
            username: Username for database session
        """
        self.username = username

    def delete_collection(
        self, collection_id: str, delete_orphaned_documents: bool = True
    ) -> Dict[str, Any]:
        """
        Delete a collection and clean up all related data.

        By default, orphaned documents (not in any other collection) are deleted.
        Set delete_orphaned_documents=False to preserve all documents.

        The following are deleted:
        - DocumentChunks for this collection
        - FAISS index files
        - RAGIndex records
        - CollectionFolder records (CASCADE)
        - DocumentCollection links (CASCADE)
        - RagDocumentStatus records (CASCADE)
        - Orphaned documents (if delete_orphaned_documents=True)

        Args:
            collection_id: ID of the collection to delete
            delete_orphaned_documents: If True, delete documents not in any
                other collection after unlinking

        Returns:
            Dict with deletion details:
            {
                "deleted": True/False,
                "collection_id": str,
                "collection_name": str,
                "chunks_deleted": int,
                "documents_unlinked": int,
                "indices_deleted": int,
                "folders_deleted": int,
                "orphaned_documents_deleted": int,
                "error": str (if failed)
            }
        """
        with get_user_db_session(self.username) as session:
            try:
                # Get collection
                collection = session.query(Collection).get(collection_id)
                if not collection:
                    return {
                        "deleted": False,
                        "collection_id": collection_id,
                        "error": "Collection not found",
                    }

                collection_name = f"collection_{collection_id}"
                result = {
                    "deleted": False,
                    "collection_id": collection_id,
                    "collection_name": collection.name,
                    "chunks_deleted": 0,
                    "documents_unlinked": 0,
                    "indices_deleted": 0,
                    "folders_deleted": 0,
                    "orphaned_documents_deleted": 0,
                }

                # 1. Get document IDs BEFORE deleting links (for orphan check)
                doc_ids_in_collection = [
                    dc.document_id
                    for dc in session.query(DocumentCollection)
                    .filter_by(collection_id=collection_id)
                    .all()
                ]
                result["documents_unlinked"] = len(doc_ids_in_collection)

                # 2. Delete DocumentChunks for this collection
                result["chunks_deleted"] = (
                    CascadeHelper.delete_collection_chunks(
                        session, collection_name
                    )
                )

                # 3. Delete RAGIndex records and FAISS files
                rag_result = CascadeHelper.delete_rag_indices_for_collection(
                    session, collection_name
                )
                result["indices_deleted"] = rag_result["deleted_indices"]

                # 4. Count folders before deletion
                result["folders_deleted"] = (
                    session.query(CollectionFolder)
                    .filter_by(collection_id=collection_id)
                    .count()
                )

                # 5. Delete DocumentCollection links explicitly before collection
                session.query(DocumentCollection).filter_by(
                    collection_id=collection_id
                ).delete(synchronize_session=False)

                # 6. Delete linked folders explicitly
                session.query(CollectionFolder).filter_by(
                    collection_id=collection_id
                ).delete(synchronize_session=False)

                # 7. Delete the collection itself
                session.delete(collection)

                # 8. Delete orphaned documents if requested
                if delete_orphaned_documents:
                    for doc_id in doc_ids_in_collection:
                        # Check if document is in any other collection
                        remaining = (
                            session.query(DocumentCollection)
                            .filter_by(document_id=doc_id)
                            .count()
                        )
                        if remaining == 0:
                            # Document is orphaned - delete it
                            CascadeHelper.delete_document_completely(
                                session, doc_id
                            )
                            result["orphaned_documents_deleted"] += 1
                            logger.info(
                                f"Deleted orphaned document {doc_id[:8]}..."
                            )

                session.commit()

                result["deleted"] = True
                logger.info(
                    f"Deleted collection {collection_id[:8]}... "
                    f"({result['collection_name']}): {result['chunks_deleted']} chunks, "
                    f"{result['documents_unlinked']} documents unlinked, "
                    f"{result['orphaned_documents_deleted']} orphaned deleted"
                )

                return result

            except Exception:
                logger.exception(f"Failed to delete collection {collection_id}")
                session.rollback()
                return {
                    "deleted": False,
                    "collection_id": collection_id,
                    "error": "Failed to delete collection",
                }

    def delete_collection_index_only(
        self, collection_id: str
    ) -> Dict[str, Any]:
        """
        Delete only the RAG index for a collection, keeping the collection itself.

        This is useful for rebuilding an index from scratch.

        Args:
            collection_id: ID of the collection

        Returns:
            Dict with deletion details
        """
        with get_user_db_session(self.username) as session:
            try:
                # Verify collection exists
                collection = session.query(Collection).get(collection_id)
                if not collection:
                    return {
                        "deleted": False,
                        "collection_id": collection_id,
                        "error": "Collection not found",
                    }

                collection_name = f"collection_{collection_id}"
                result = {
                    "deleted": False,
                    "collection_id": collection_id,
                    "chunks_deleted": 0,
                    "indices_deleted": 0,
                    "documents_reset": 0,
                }

                # 1. Delete DocumentChunks
                result["chunks_deleted"] = (
                    CascadeHelper.delete_collection_chunks(
                        session, collection_name
                    )
                )

                # 2. Delete RAGIndex records and FAISS files
                rag_result = CascadeHelper.delete_rag_indices_for_collection(
                    session, collection_name
                )
                result["indices_deleted"] = rag_result["deleted_indices"]

                # 3. Reset DocumentCollection indexed status
                result["documents_reset"] = (
                    session.query(DocumentCollection)
                    .filter_by(collection_id=collection_id)
                    .update({"indexed": False, "chunk_count": 0})
                )

                # 4. Delete RagDocumentStatus for this collection
                session.query(RagDocumentStatus).filter_by(
                    collection_id=collection_id
                ).delete(synchronize_session=False)

                # 5. Reset collection embedding info
                collection.embedding_model = None
                collection.embedding_model_type = None
                collection.embedding_dimension = None
                collection.chunk_size = None
                collection.chunk_overlap = None

                session.commit()
                result["deleted"] = True

                logger.info(
                    f"Deleted index for collection {collection_id[:8]}...: "
                    f"{result['chunks_deleted']} chunks, "
                    f"{result['documents_reset']} documents reset"
                )

                return result

            except Exception:
                logger.exception(
                    f"Failed to delete index for collection {collection_id}"
                )
                session.rollback()
                return {
                    "deleted": False,
                    "collection_id": collection_id,
                    "error": "Failed to delete collection index",
                }

    def get_deletion_preview(self, collection_id: str) -> Dict[str, Any]:
        """
        Get a preview of what will be deleted.

        Useful for showing the user what will happen before confirming.

        Args:
            collection_id: ID of the collection

        Returns:
            Dict with preview information
        """
        with get_user_db_session(self.username) as session:
            collection = session.query(Collection).get(collection_id)
            if not collection:
                return {"found": False, "collection_id": collection_id}

            collection_name = f"collection_{collection_id}"

            # Count documents
            documents_count = (
                session.query(DocumentCollection)
                .filter_by(collection_id=collection_id)
                .count()
            )

            # Count chunks
            chunks_count = (
                session.query(DocumentChunk)
                .filter_by(collection_name=collection_name)
                .count()
            )

            # Count folders
            folders_count = (
                session.query(CollectionFolder)
                .filter_by(collection_id=collection_id)
                .count()
            )

            # Check for RAG index
            has_index = (
                session.query(RAGIndex)
                .filter_by(collection_name=collection_name)
                .first()
                is not None
            )

            return {
                "found": True,
                "collection_id": collection_id,
                "name": collection.name,
                "description": collection.description,
                "is_default": collection.is_default,
                "documents_count": documents_count,
                "chunks_count": chunks_count,
                "folders_count": folders_count,
                "has_rag_index": has_index,
                "embedding_model": collection.embedding_model,
            }
